06ee81
@@ -21,38 +21,95 @@
import java.io.InputStream;
 import java.io.ObjectInputStream;
 import java.io.ObjectStreamClass;
 import java.lang.reflect.Proxy;
+import java.util.HashMap;
 
 import org.apache.camel.CamelContext;
 
 /**
  * This class is copied from the Apache ActiveMQ project.
+ * <p/>
+ * Class names are resolved by trying, in order, the thread context class loader, the class
+ * loader chosen at construction time, the primitive type names, and finally the class loader
+ * that loaded this class.
+ * <p/>
+ * NOTE(review): no allow-list or class filtering is applied when resolving classes, so this
+ * stream should not be used to read data from untrusted sources.
  */
 public class ClassLoadingAwareObjectInputStream extends ObjectInputStream {
 
-    private CamelContext camelContext;
+    /**
+     * Last-resort class loader: the one that loaded this class.
+     */
+    private static final ClassLoader FALLBACK_CLASS_LOADER =
+        ClassLoadingAwareObjectInputStream.class.getClassLoader();
+
+    /**
+     * Maps primitive type names to corresponding class objects, since
+     * {@link Class#forName(String, boolean, ClassLoader)} cannot resolve them.
+     */
+    private static final HashMap<String, Class<?>> PRIM_CLASSES = new HashMap<String, Class<?>>();
+
+    static {
+        PRIM_CLASSES.put("boolean", boolean.class);
+        PRIM_CLASSES.put("byte", byte.class);
+        PRIM_CLASSES.put("char", char.class);
+        PRIM_CLASSES.put("short", short.class);
+        PRIM_CLASSES.put("int", int.class);
+        PRIM_CLASSES.put("long", long.class);
+        PRIM_CLASSES.put("float", float.class);
+        PRIM_CLASSES.put("double", double.class);
+        PRIM_CLASSES.put("void", void.class);
+    }
+
+    /**
+     * Class loader tried after the thread context class loader; {@code null} means the
+     * bootstrap class loader.
+     */
+    private final ClassLoader inLoader;
+
+    /**
+     * Creates a stream that falls back to the class loader of {@code in}'s runtime class.
+     *
+     * @param in the stream to read serialized objects from
+     * @throws IOException if the stream header cannot be read
+     */
+    public ClassLoadingAwareObjectInputStream(InputStream in) throws IOException {
+        super(in);
+        inLoader = in.getClass().getClassLoader();
+    }
 
+    /**
+     * Creates a stream that falls back to the application context class loader of the
+     * given {@link CamelContext}.
+     *
+     * @param camelContext the context supplying the application context class loader
+     * @param in the stream to read serialized objects from
+     * @throws IOException if the stream header cannot be read
+     */
     public ClassLoadingAwareObjectInputStream(CamelContext camelContext, InputStream in) throws IOException {
         super(in);
-        this.camelContext = camelContext;
+        inLoader = camelContext.getApplicationContextClassLoader();
     }
 
     @Override
     protected Class<?> resolveClass(ObjectStreamClass classDesc) throws IOException, ClassNotFoundException {
-        return camelContext.getClassResolver().resolveClass(classDesc.getName());
+        ClassLoader cl = Thread.currentThread().getContextClassLoader();
+        return load(classDesc.getName(), cl, inLoader);
     }
 
     @Override
     protected Class<?> resolveProxyClass(String[] interfaces) throws IOException, ClassNotFoundException {
-        Class<?>[] cinterfaces = new Class[interfaces.length];
+        ClassLoader cl = Thread.currentThread().getContextClassLoader();
+        Class<?>[] cinterfaces = new Class<?>[interfaces.length];
         for (int i = 0; i < interfaces.length; i++) {
-            cinterfaces[i] = camelContext.getClassResolver().resolveClass(interfaces[i]);
+            // use the same loaders as resolveClass so both paths see the same classes
+            cinterfaces[i] = load(interfaces[i], cl, inLoader);
         }
 
         try {
-            return Proxy.getProxyClass(cinterfaces[0].getClassLoader(), cinterfaces);
+            return Proxy.getProxyClass(cl, cinterfaces);
         } catch (IllegalArgumentException e) {
+            // the interfaces are not all visible from cl; try the remaining loaders in order
+            for (ClassLoader loader : new ClassLoader[] {inLoader, FALLBACK_CLASS_LOADER}) {
+                try {
+                    return Proxy.getProxyClass(loader, cinterfaces);
+                } catch (IllegalArgumentException ignored) {
+                    // not usable from this loader either; the first failure is reported below
+                }
+            }
             throw new ClassNotFoundException(null, e);
         }
     }
 
+    /**
+     * Loads a class by name without initializing it.
+     *
+     * @param className the binary class name, or a primitive type name such as {@code int}
+     * @param loaders the class loaders to try first, in order; a {@code null} entry means the
+     *            bootstrap class loader
+     * @return the resolved class
+     * @throws ClassNotFoundException if none of the loaders, the primitive names, or
+     *             {@link #FALLBACK_CLASS_LOADER} can resolve {@code className}
+     */
+    private Class<?> load(String className, ClassLoader... loaders) throws ClassNotFoundException {
+        for (ClassLoader loader : loaders) {
+            try {
+                return Class.forName(className, false, loader);
+            } catch (ClassNotFoundException ignored) {
+                // not visible from this loader; try the next one
+            }
+        }
+        Class<?> primitive = PRIM_CLASSES.get(className);
+        if (primitive != null) {
+            return primitive;
+        }
+        return Class.forName(className, false, FALLBACK_CLASS_LOADER);
+    }
 }
